import argparse
import os
import random
import sys
from datetime import datetime
from importlib import import_module

import numpy as np
import torch
from fastargs import Param, Section, get_current_config
from fastargs.decorators import param
from fastargs.validation import BoolAsInt, File, Folder, OneOf

# Extend the import search path with the relative directory "src".
# NOTE(review): presumably this is where the `model`, `pruner` and `loggers`
# packages loaded by name further down live; a relative entry is resolved
# against the current working directory, so the launch directory matters —
# confirm the expected launch location.
sys.path.append("src")

# Top-level run settings. `model_name` is the only required option; `logger`
# picks the logger backend and also gates the "logger.json" section.
# NOTE(review): `Folder(True)` presumably asks fastargs to create the folder
# when it is missing — confirm against the fastargs validators.
Section("overall", "Overall configs").params(
    model_name=Param(str, required=True, desc="Model name"),
    logger=Param(OneOf(["json", "none"]), default="none", desc="Logger to use"),
    cache_dir=Param(Folder(True), default=".cache", desc="Cache directory"),
    seed=Param(int, default=0, desc="Random seed"),
)
# Pruning settings. `sparsity_type` also selects which pruner module is
# imported and gates the "pruner.structured" section.
Section("pruner", "Pruning configs").params(
    sparsity_ratio=Param(float, default=0.0, desc="Pruning ratio"),
    scores=Param(
        OneOf(["weight", "wanda", "hessianfree", "random", "llmpruner_1", "hessian"]),
        default="weight",
        desc="Scores to use",
    ),
    sparsity_type=Param(
        OneOf(["unstructured", "structured"]),
        default="unstructured",
        desc="Sparsity type",
    ),
    n_samples=Param(int, default=128, desc="Number of samples to use for WANDA"),
    dataset_name=Param(str, default="c4", desc="Dataset name"),
    # The help text must differ from `dataset_name` above, otherwise the two
    # options are indistinguishable in `--help` and in the config summary.
    eval_dataset_name=Param(str, default="c4", desc="Evaluation dataset name"),
)
# Extra options that are only active when `pruner.sparsity_type` is
# "structured" (see the `enable_if` predicate).
# NOTE(review): the names `prune_n` / `prune_m` look like an N:M sparsity
# pattern, while the descriptions talk about heads and neurons — confirm which
# meaning the structured pruner actually implements.
Section("pruner.structured", "Structured pruning configs").enable_if(
    lambda cfg: cfg["pruner.sparsity_type"] == "structured"
).params(
    prune_n=Param(int, default=2, desc="Number of heads to prune"),
    prune_m=Param(int, default=4, desc="Number of neurons to prune"),
)
# Training settings; these are forwarded to the model factory together with
# the "overall" section.
# NOTE(review): `BoolAsInt()` presumably expects 0/1 on the command line —
# confirm against the fastargs validators.
Section("trainer", "Trainer configs").params(
    recovery=Param(BoolAsInt(), default=False, desc="Whether to use self-training"),
    epochs=Param(int, default=1, desc="Number of epochs to train"),
    lr=Param(float, default=1e-4, desc="Learning rate"),
    num_warmup_steps=Param(int, default=0, desc="Number of warmup steps"),
    dataset_name=Param(str, default="c4", desc="Dataset name"),
    sparse_training=Param(
        BoolAsInt(), default=True, desc="Whether to use sparse training"
    ),
    batch_size=Param(int, default=16, desc="Batch size"),
)
# Options shared by every logger backend. The default run name is a
# microsecond-resolution timestamp; it is computed once, when this module is
# executed, not each time the option is read.
Section("logger", "General logger configs").params(
    name=Param(
        str,
        default=datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f"),
        desc="Name of this run",
    ),
)

# Options for the JSON logger; only active when `overall.logger` is "json"
# (see the `enable_if` predicate).
Section("logger.json", "JSON logger").enable_if(
    lambda cfg: cfg["overall.logger"] == "json"
).params(
    root=Param(Folder(True), default="files/logs", desc="Path to log folder"),
)


class Main:
    """Driver that wires configuration, model, pruner and logger together.

    Instantiating the class executes the whole pipeline: the constructor
    parses the configuration, seeds the random number generators, builds the
    model, pruner and logger, and finally calls :meth:`run`.
    """

    def __init__(self) -> None:
        self.make_config()
        self.setup_seed()
        self.init_model()
        self.init_pruner()
        self.init_logger()
        self.run()

    def make_config(self, quiet: bool = False) -> None:
        """Populate the global fastargs config from the command line.

        The config is stored on ``self.config`` and validated.

        Args:
            quiet: When true, the configuration summary is not printed.
        """
        self.config = get_current_config()
        parser = argparse.ArgumentParser("LLM pruning")
        self.config.augment_argparse(parser)
        self.config.collect_argparse_args(parser)

        self.config.validate()
        if not quiet:
            self.config.summary()

    @param("overall.seed")
    def setup_seed(self, seed: int) -> None:
        """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs.

        cuDNN is disabled altogether, and its benchmark/deterministic flags
        are set to their reproducible values.

        Args:
            seed: Value of ``overall.seed``, injected by fastargs.
        """
        random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        torch.backends.cudnn.enabled = False
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    @param("overall.model_name")
    def init_model(self, model_name) -> None:
        """Build ``self.model`` from the "overall" and "trainer" sections.

        Both sections are merged into one keyword dict (trainer values win on
        a key collision) and passed to ``model.base.get``.

        Args:
            model_name: Value of ``overall.model_name``, injected by fastargs.
                It is not used directly here; it is declared in the "overall"
                section, so it is expected to reach the factory through the
                merged keyword arguments.
        """
        kwargs = self.config.get_section("overall")
        kwargs.update(self.config.get_section("trainer"))
        self.model = import_module("model.base").get(**kwargs)

    @param("pruner.sparsity_type")
    def init_pruner(self, sparsity_type) -> None:
        """Build ``self.pruner`` for the configured sparsity type.

        The "pruner" section is merged with the ``pruner.<sparsity_type>``
        section and passed to the ``get`` factory of the module
        ``pruner.<sparsity_type>``.

        Args:
            sparsity_type: Value of ``pruner.sparsity_type``, injected by
                fastargs.
        """
        kwargs = self.config.get_section("pruner")
        kwargs.update(self.config.get_section(f"pruner.{sparsity_type}"))
        self.pruner = import_module(f"pruner.{sparsity_type}").get(**kwargs)

    @param("overall.logger")
    def init_logger(self, logger) -> None:
        """Build ``self.logger`` for the configured logger backend.

        The "logger" section is merged with the ``logger.<logger>`` section,
        the complete configuration is added under the ``config`` key, and the
        result is passed to the ``get`` factory of ``loggers.<logger>_``
        (note the trailing underscore in the module name).

        Args:
            logger: Value of ``overall.logger``, injected by fastargs.
        """
        kwargs = self.config.get_section("logger")
        kwargs.update(self.config.get_section(f"logger.{logger}"))
        kwargs["config"] = self.config.get_all_config()
        self.logger = import_module(f"loggers.{logger}_").get(**kwargs)

    def run(self) -> None:
        """Prune, save, recover and evaluate the model, in that order."""
        self.model.prune(self.pruner, self.logger)
        self.model.save(self.logger)
        self.model.recover(self.logger)
        self.model.eval(self.logger)


if __name__ == "__main__":
    # Constructing Main is enough: its __init__ parses the CLI arguments and
    # runs the full pipeline, so the instance does not need to be kept.
    Main()
